{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create a Kubeflow Pipeline with BERT and Amazon SageMaker"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Install Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!pip install sagemaker==1.72.0\n",
    "!pip install https://storage.googleapis.com/ml-pipeline/release/0.1.29/kfp.tar.gz --upgrade"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install awscli==1.18.140"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# _Note: Ignore all pip install warnings or errors above_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restart the kernel to pick up pip installed libraries\n",
    "from IPython.core.display import HTML\n",
    "\n",
    "# The HTML object is the cell's last expression, so it is displayed as the cell output;\n",
    "# the embedded script calls Jupyter.notebook.kernel.restart().\n",
    "# NOTE(review): presumably this needs a front end that exposes a global `Jupyter`\n",
    "# JavaScript object (classic Notebook) -- confirm it works in the UI you use.\n",
    "HTML(\"<script>Jupyter.notebook.kernel.restart()</script>\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup Environment Variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "\n",
    "# Read the availability zone from the instance metadata endpoint and let sed drop the\n",
    "# trailing lowercase letter; the shell output is captured by IPython.\n",
    "aws_region_as_slist = !curl -s http://169.254.169.254/latest/meta-data/placement/availability-zone | sed 's/\\(.*\\)[a-z]/\\1/'\n",
    "region = aws_region_as_slist.s\n",
    "print(\"Region: %s\" % region)\n",
    "\n",
    "# Account id of the credentials currently in use, as reported by STS.\n",
    "account_id = boto3.client(\"sts\").get_caller_identity().get(\"Account\")\n",
    "print(\"Account ID: %s\" % account_id)\n",
    "\n",
    "# Bucket name is built from region and account id.\n",
    "bucket = \"sagemaker-%s-%s\" % (region, account_id)\n",
    "print(\"S3 Bucket: %s\" % bucket)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect every IAM role in the account. list_roles() returns one page per call, so use\n",
    "# a paginator instead of reading only the first response.\n",
    "iam_roles = [\n",
    "    iam_role\n",
    "    for page in boto3.client(\"iam\").get_paginator(\"list_roles\").paginate()\n",
    "    for iam_role in page[\"Roles\"]\n",
    "]\n",
    "\n",
    "# ARN of the first role whose name contains \"SageMakerExecutionRole\", or None.\n",
    "role = next(\n",
    "    (iam_role[\"Arn\"] for iam_role in iam_roles if \"SageMakerExecutionRole\" in iam_role[\"RoleName\"]),\n",
    "    None,\n",
    ")\n",
    "\n",
    "# Fail here with a clear message instead of a NameError (or a stale value left over from\n",
    "# an earlier run) when no matching role exists.\n",
    "if role is None:\n",
    "    raise RuntimeError(\"No IAM role whose name contains 'SageMakerExecutionRole' was found.\")\n",
    "print(\"Role: {}\".format(role))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Copy Data from Public S3 to Private S3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Source prefix of the review dataset (TSV files).\n",
    "s3_public_path_tsv = \"s3://amazon-reviews-pds/tsv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Destination prefix inside the bucket named by `bucket`.\n",
    "s3_private_path_tsv = \"s3://%s/amazon-reviews-pds/tsv\" % bucket\n",
    "print(s3_private_path_tsv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the three review files in a single pass: --exclude \"*\" filters out every object\n",
    "# under the prefix and each --include adds one file back, so the source prefix is\n",
    "# scanned once instead of once per file.\n",
    "!aws s3 cp --recursive $s3_public_path_tsv/ $s3_private_path_tsv/ --exclude \"*\" --include \"amazon_reviews_us_Digital_Software_v1_00.tsv.gz\" --include \"amazon_reviews_us_Digital_Video_Games_v1_00.tsv.gz\" --include \"amazon_reviews_us_Gift_Card_v1_00.tsv.gz\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefix (with trailing slash) that the pipeline reads its raw TSV input from.\n",
    "raw_input_data_s3_uri = \"s3://\" + \"{}\".format(bucket) + \"/amazon-reviews-pds/tsv/\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build Pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using Amazon SageMaker Components for Kubeflow Pipelines: \n",
    "https://github.com/kubeflow/pipelines/tree/master/components/aws/sagemaker\n",
    "\n",
    "https://docs.aws.amazon.com/sagemaker/latest/dg/usingamazon-sagemaker-components.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import kfp\n",
    "from kfp import components\n",
    "from kfp import dsl\n",
    "from kfp.aws import use_aws_secret"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Component factory for the SageMaker \"process\" component; the URL points at a fixed\n",
    "# commit of kubeflow/pipelines rather than a branch.\n",
    "sagemaker_process_op = components.load_component_from_url(\n",
    "    \"https://raw.githubusercontent.com/kubeflow/pipelines/3ebd075212e0a761b982880707ec497c36a99d80/components/aws/sagemaker/process/component.yaml\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Component factory for the SageMaker \"train\" component (same pinned commit).\n",
    "sagemaker_train_op = components.load_component_from_url(\n",
    "    \"https://raw.githubusercontent.com/kubeflow/pipelines/3ebd075212e0a761b982880707ec497c36a99d80/components/aws/sagemaker/train/component.yaml\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Component factory for the SageMaker \"model\" component (same pinned commit).\n",
    "sagemaker_model_op = components.load_component_from_url(\n",
    "    \"https://raw.githubusercontent.com/kubeflow/pipelines/3ebd075212e0a761b982880707ec497c36a99d80/components/aws/sagemaker/model/component.yaml\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Component factory for the SageMaker \"deploy\" component (same pinned commit).\n",
    "sagemaker_deploy_op = components.load_component_from_url(\n",
    "    \"https://raw.githubusercontent.com/kubeflow/pipelines/3ebd075212e0a761b982880707ec497c36a99d80/components/aws/sagemaker/deploy/component.yaml\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup Pre-Processing Code "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# S3 location the preprocessing script is uploaded to.\n",
    "processing_code_s3_uri = \"s3://\" + \"{}\".format(bucket) + \"/processing_code/preprocess-scikit-text-to-bert-feature-store.py\"\n",
    "print(processing_code_s3_uri)\n",
    "\n",
    "# Upload the local script to that location.\n",
    "!aws s3 cp ./preprocess-scikit-text-to-bert-feature-store.py {processing_code_s3_uri}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Package and Upload Training Code to S3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bundle the contents of ./code into a gzip-compressed tarball.\n",
    "!tar -cvzf sourcedir.tar.gz -C ./code ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# S3 location the training code tarball is uploaded to.\n",
    "training_code_s3_uri = \"s3://\" + \"{}\".format(bucket) + \"/training_code/sourcedir.tar.gz\"\n",
    "print(training_code_s3_uri)\n",
    "\n",
    "# Upload the tarball created in the previous cell.\n",
    "!aws s3 cp sourcedir.tar.gz {training_code_s3_uri}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def processing_input(input_name, s3_uri, local_path, s3_data_distribution_type):\n",
    "    return {\n",
    "        \"InputName\": input_name,\n",
    "        \"S3Input\": {\n",
    "            \"LocalPath\": local_path,\n",
    "            \"S3Uri\": s3_uri,\n",
    "            \"S3DataType\": \"S3Prefix\",\n",
    "            \"S3DataDistributionType\": s3_data_distribution_type,\n",
    "            \"S3InputMode\": \"File\",\n",
    "        },\n",
    "    }\n",
    "\n",
    "\n",
    "def processing_output(output_name, s3_uri, local_path, s3_upload_mode):\n",
    "    return {\n",
    "        \"OutputName\": output_name,\n",
    "        \"S3Output\": {\"LocalPath\": local_path, \"S3Uri\": s3_uri, \"S3UploadMode\": s3_upload_mode},\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def training_input(input_name, s3_uri, s3_data_distribution_type):\n",
    "    return {\n",
    "        \"ChannelName\": input_name,\n",
    "        \"DataSource\": {\n",
    "            \"S3DataSource\": {\n",
    "                \"S3Uri\": s3_uri,\n",
    "                \"S3DataType\": \"S3Prefix\",\n",
    "                \"S3DataDistributionType\": s3_data_distribution_type,\n",
    "            }\n",
    "        },\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pipeline definition with four steps: processing (feature engineering), training,\n",
    "# model creation and endpoint deployment. The parameter defaults are the notebook\n",
    "# variables `role`, `bucket`, `region` and `raw_input_data_s3_uri` set in earlier cells.\n",
    "@dsl.pipeline(\n",
    "    name=\"BERT Pipeline\",\n",
    "    description=\"BERT Pipeline\",\n",
    ")\n",
    "def bert_pipeline(role=role, bucket=bucket, region=region, raw_input_data_s3_uri=raw_input_data_s3_uri):\n",
    "\n",
    "    import time\n",
    "    import json\n",
    "\n",
    "    # Timestamped name, used below as the S3 prefix for processing and training outputs.\n",
    "    pipeline_name = \"kubeflow-pipeline-sagemaker-{}\".format(int(time.time()))\n",
    "\n",
    "    # Passed unchanged to the process, train and model steps.\n",
    "    network_isolation = False\n",
    "\n",
    "    ########################\n",
    "    # FEATURE ENGINEERING\n",
    "    ########################\n",
    "\n",
    "    max_seq_length = 64\n",
    "    train_split_percentage = 0.90\n",
    "    validation_split_percentage = 0.05\n",
    "    test_split_percentage = 0.05\n",
    "    balance_dataset = True\n",
    "\n",
    "    processed_train_data_s3_uri = \"s3://{}/{}/processing/output/bert-train\".format(bucket, pipeline_name)\n",
    "    processed_validation_data_s3_uri = \"s3://{}/{}/processing/output/bert-validation\".format(bucket, pipeline_name)\n",
    "    processed_test_data_s3_uri = \"s3://{}/{}/processing/output/bert-test\".format(bucket, pipeline_name)\n",
    "\n",
    "    processing_instance_type = \"ml.c5.2xlarge\"\n",
    "    processing_instance_count = 2\n",
    "\n",
    "    # Separate timestamp used only for the feature store prefix and feature group name.\n",
    "    timestamp = int(time.time())\n",
    "\n",
    "    feature_store_offline_prefix = \"reviews-feature-store-\" + str(timestamp)\n",
    "    feature_group_name = \"reviews-feature-group-\" + str(timestamp)\n",
    "\n",
    "    # hard-coding to avoid the wrong ECR account id with create_image_uri()\n",
    "    processing_image = \"683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3\"\n",
    "    #    import sagemaker\n",
    "    #    processing_image = sagemaker.fw_utils.create_image_uri(framework='scikit-learn',\n",
    "    #                                                           framework_version='0.23-1',\n",
    "    #                                                           py_version='py3',\n",
    "    #                                                           instance_type='ml.c5.9xlarge',\n",
    "    #                                                           region='us-east-1') # hard-coding to avoid serialization issue\n",
    "\n",
    "    # Processing step: the script arguments are passed as strings, and the script itself\n",
    "    # is supplied through the \"code\" input from `processing_code_s3_uri` (a notebook\n",
    "    # variable defined in an earlier cell).\n",
    "    process = sagemaker_process_op(\n",
    "        role=role,\n",
    "        region=region,\n",
    "        image=processing_image,\n",
    "        network_isolation=network_isolation,\n",
    "        instance_type=processing_instance_type,\n",
    "        instance_count=processing_instance_count,\n",
    "        container_arguments=[\n",
    "            \"--train-split-percentage\",\n",
    "            str(train_split_percentage),\n",
    "            \"--validation-split-percentage\",\n",
    "            str(validation_split_percentage),\n",
    "            \"--test-split-percentage\",\n",
    "            str(test_split_percentage),\n",
    "            \"--max-seq-length\",\n",
    "            str(max_seq_length),\n",
    "            \"--balance-dataset\",\n",
    "            str(balance_dataset),\n",
    "            \"--feature-store-offline-prefix\",\n",
    "            str(feature_store_offline_prefix),\n",
    "            \"--feature-group-name\",\n",
    "            str(feature_group_name),\n",
    "        ],\n",
    "        environment={\"AWS_DEFAULT_REGION\": \"us-east-1\"},  # hard-coding to avoid serialization issue\n",
    "        container_entrypoint=[\n",
    "            \"python3\",\n",
    "            \"/opt/ml/processing/input/code/preprocess-scikit-text-to-bert-feature-store.py\",\n",
    "        ],\n",
    "        input_config=[\n",
    "            processing_input(\n",
    "                input_name=\"raw-input-data\",\n",
    "                s3_uri=\"{}\".format(raw_input_data_s3_uri),\n",
    "                local_path=\"/opt/ml/processing/input/data/\",\n",
    "                s3_data_distribution_type=\"ShardedByS3Key\",\n",
    "            ),\n",
    "            processing_input(\n",
    "                input_name=\"code\",\n",
    "                s3_uri=\"{}\".format(processing_code_s3_uri),\n",
    "                local_path=\"/opt/ml/processing/input/code\",\n",
    "                s3_data_distribution_type=\"FullyReplicated\",\n",
    "            ),\n",
    "        ],\n",
    "        output_config=[\n",
    "            processing_output(\n",
    "                output_name=\"bert-train\",\n",
    "                s3_uri=\"{}\".format(processed_train_data_s3_uri),\n",
    "                local_path=\"/opt/ml/processing/output/bert/train\",\n",
    "                s3_upload_mode=\"EndOfJob\",\n",
    "            ),\n",
    "            processing_output(\n",
    "                output_name=\"bert-validation\",\n",
    "                s3_uri=\"{}\".format(processed_validation_data_s3_uri),\n",
    "                local_path=\"/opt/ml/processing/output/bert/validation\",\n",
    "                s3_upload_mode=\"EndOfJob\",\n",
    "            ),\n",
    "            processing_output(\n",
    "                output_name=\"bert-test\",\n",
    "                s3_uri=\"{}\".format(processed_test_data_s3_uri),\n",
    "                local_path=\"/opt/ml/processing/output/bert/test\",\n",
    "                s3_upload_mode=\"EndOfJob\",\n",
    "            ),\n",
    "        ],\n",
    "    )\n",
    "\n",
    "    ########################\n",
    "    # TRAIN\n",
    "    ########################\n",
    "\n",
    "    # The training channels read from the same S3 prefixes the processing step writes to.\n",
    "    train_channels = [\n",
    "        training_input(\n",
    "            input_name=\"train\", s3_uri=processed_train_data_s3_uri, s3_data_distribution_type=\"ShardedByS3Key\"\n",
    "        ),\n",
    "        training_input(\n",
    "            input_name=\"validation\",\n",
    "            s3_uri=processed_validation_data_s3_uri,\n",
    "            s3_data_distribution_type=\"ShardedByS3Key\",\n",
    "        ),\n",
    "        training_input(\n",
    "            input_name=\"test\", s3_uri=processed_test_data_s3_uri, s3_data_distribution_type=\"ShardedByS3Key\"\n",
    "        ),\n",
    "    ]\n",
    "\n",
    "    epochs = 1\n",
    "    learning_rate = 0.00001\n",
    "    epsilon = 0.00000001\n",
    "    train_batch_size = 128\n",
    "    validation_batch_size = 128\n",
    "    test_batch_size = 128\n",
    "    train_steps_per_epoch = 100\n",
    "    validation_steps = 100\n",
    "    test_steps = 100\n",
    "    train_volume_size = 1024\n",
    "    use_xla = True\n",
    "    use_amp = True\n",
    "    freeze_bert_layer = False\n",
    "    enable_sagemaker_debugger = False\n",
    "    enable_checkpointing = False\n",
    "    enable_tensorboard = False\n",
    "    input_mode = \"File\"\n",
    "    run_validation = True\n",
    "    run_test = True\n",
    "    run_sample_predictions = True\n",
    "\n",
    "    train_instance_type = \"ml.c5.9xlarge\"\n",
    "    train_instance_count = 1\n",
    "\n",
    "    train_output_location = \"s3://{}/{}/output\".format(bucket, pipeline_name)\n",
    "\n",
    "    # Every hyperparameter value is converted to a string before JSON serialization.\n",
    "    # `training_code_s3_uri` is a notebook variable defined in an earlier cell.\n",
    "    hyperparameters = {\n",
    "        \"epochs\": \"{}\".format(epochs),\n",
    "        \"learning_rate\": \"{}\".format(learning_rate),\n",
    "        \"epsilon\": \"{}\".format(epsilon),\n",
    "        \"train_batch_size\": \"{}\".format(train_batch_size),\n",
    "        \"validation_batch_size\": \"{}\".format(validation_batch_size),\n",
    "        \"test_batch_size\": \"{}\".format(test_batch_size),\n",
    "        \"train_steps_per_epoch\": \"{}\".format(train_steps_per_epoch),\n",
    "        \"validation_steps\": \"{}\".format(validation_steps),\n",
    "        \"test_steps\": \"{}\".format(test_steps),\n",
    "        \"use_xla\": \"{}\".format(use_xla),\n",
    "        \"use_amp\": \"{}\".format(use_amp),\n",
    "        \"max_seq_length\": \"{}\".format(max_seq_length),\n",
    "        \"freeze_bert_layer\": \"{}\".format(freeze_bert_layer),\n",
    "        \"enable_sagemaker_debugger\": \"{}\".format(enable_sagemaker_debugger),\n",
    "        \"enable_checkpointing\": \"{}\".format(enable_checkpointing),\n",
    "        \"enable_tensorboard\": \"{}\".format(enable_tensorboard),\n",
    "        \"run_validation\": \"{}\".format(run_validation),\n",
    "        \"run_test\": \"{}\".format(run_test),\n",
    "        \"run_sample_predictions\": \"{}\".format(run_sample_predictions),\n",
    "        \"model_dir\": \"{}\".format(train_output_location),\n",
    "        \"sagemaker_program\": \"tf_bert_reviews.py\",\n",
    "        \"sagemaker_region\": \"{}\".format(region),\n",
    "        \"sagemaker_submit_directory\": training_code_s3_uri,\n",
    "    }\n",
    "    hyperparameters_json = json.dumps(hyperparameters)\n",
    "\n",
    "    # metric_definitions='{\"val_acc\": \"val_accuracy: ([0-9\\\\\\\\.]+)\"}',\n",
    "    # NOTE: the definitions below are only printed; the `metric_definitions` argument of\n",
    "    # the training step further down is commented out.\n",
    "    metrics_definitions = [\n",
    "        {\"Name\": \"train:loss\", \"Regex\": \"loss: ([0-9\\\\.]+)\"},\n",
    "        {\"Name\": \"train:accuracy\", \"Regex\": \"accuracy: ([0-9\\\\.]+)\"},\n",
    "        {\"Name\": \"validation:loss\", \"Regex\": \"val_loss: ([0-9\\\\.]+)\"},\n",
    "        {\"Name\": \"validation:accuracy\", \"Regex\": \"val_accuracy: ([0-9\\\\.]+)\"},\n",
    "    ]\n",
    "    metrics_definitions_json = json.dumps(metrics_definitions)\n",
    "    print(metrics_definitions_json)\n",
    "\n",
    "    # .after(process) is explicitly appended below\n",
    "    train_image = \"763104351884.dkr.ecr.{}.amazonaws.com/tensorflow-training:2.3.1-cpu-py37-ubuntu18.04\".format(region)\n",
    "    training = sagemaker_train_op(\n",
    "        region=region,\n",
    "        image=train_image,\n",
    "        network_isolation=network_isolation,\n",
    "        instance_type=train_instance_type,\n",
    "        instance_count=train_instance_count,\n",
    "        hyperparameters=hyperparameters_json,\n",
    "        training_input_mode=input_mode,\n",
    "        channels=train_channels,\n",
    "        model_artifact_path=train_output_location,\n",
    "        # metric_definitions=metrics_definitions_json,\n",
    "        # TODO:  Add rules\n",
    "        role=role,\n",
    "    ).after(process)\n",
    "\n",
    "    ########################\n",
    "    # DEPLOY\n",
    "    ########################\n",
    "\n",
    "    # .after(training) is implied because we depend on training.outputs[]\n",
    "    serve_image = \"763104351884.dkr.ecr.{}.amazonaws.com/tensorflow-inference:2.3.1-cpu\".format(region)\n",
    "    create_model = sagemaker_model_op(\n",
    "        region=region,\n",
    "        model_name=training.outputs[\"job_name\"],\n",
    "        image=serve_image,\n",
    "        network_isolation=network_isolation,\n",
    "        model_artifact_url=training.outputs[\"model_artifact_url\"],\n",
    "        role=role,\n",
    "    )\n",
    "\n",
    "    deploy_instance_type = \"ml.c5.9xlarge\"\n",
    "    deploy_instance_count = 1\n",
    "\n",
    "    # .after(create_model) is implied because we depend on create_model.outputs\n",
    "    deploy_model = sagemaker_deploy_op(\n",
    "        region=region,\n",
    "        variant_name_1=\"AllTraffic\",\n",
    "        model_name_1=create_model.output,\n",
    "        instance_type_1=deploy_instance_type,\n",
    "        initial_instance_count_1=deploy_instance_count,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compile Kubeflow Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile the pipeline function and write the result to bert-pipeline.zip.\n",
    "kfp.compiler.Compiler().compile(bert_pipeline, \"bert-pipeline.zip\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List the compiled package file.\n",
    "!ls -al ./bert-pipeline.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract the package into the current directory; -o overwrites existing files without asking.\n",
    "!unzip -o ./bert-pipeline.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print pipeline.yaml (presumably extracted from the zip above) with syntax highlighting.\n",
    "!pygmentize pipeline.yaml"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Launch Pipeline on Kubernetes Cluster"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Client created with its default connection settings.\n",
    "client = kfp.Client()\n",
    "\n",
    "# Experiment the run is submitted under.\n",
    "experiment = client.create_experiment(name=\"kubeflow\")\n",
    "\n",
    "# Submit the compiled package as a run named \"bert-pipeline\".\n",
    "my_run = client.run_pipeline(\n",
    "    experiment_id=experiment.id,\n",
    "    job_name=\"bert-pipeline\",\n",
    "    pipeline_package_path=\"bert-pipeline.zip\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Review Training Job\n",
    "\n",
    "_Note:  The above training job may take 5-10 minutes.  Please be patient._\n",
    "\n",
    "In the meantime, open the SageMaker Console to monitor the progress of your training job.\n",
    "\n",
    "![SageMaker Training Job Console](img/sagemaker-training-job-console.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Review the Endpoint\n",
    "First, we need to get the endpoint name of our newly-deployed SageMaker Prediction Endpoint."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Open the AWS console, go to the SageMaker service, and find the endpoint name as shown in the following picture.\n",
    "\n",
    "![download-pipeline](img/sm-endpoint.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Make a Prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## _YOU MUST COPY/PASTE THE `endpoint_name` BEFORE CONTINUING_\n",
    "Make sure to preserve the single quotes as shown below.\n",
    "\n",
    "![](img/sm-endpoint-kubeflow.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "\n",
    "# Runtime client used to invoke the endpoint, created for the notebook's `region`.\n",
    "sm_runtime = boto3.Session(region_name=region).client(\"sagemaker-runtime\")\n",
    "\n",
    "# Placeholder value: replace it with the real endpoint name before running the next cell.\n",
    "endpoint_name = \"<COPY-AND-PASTE-FROM-KUBEFLOW-PIPELINE-LOGS>\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Two sample reviews; each entry carries its text under \"features\".\n",
    "inputs = [{\"features\": [\"This is great!\"]}, {\"features\": [\"This is bad.\"]}]\n",
    "\n",
    "response = sm_runtime.invoke_endpoint(\n",
    "    EndpointName=endpoint_name,\n",
    "    ContentType=\"application/jsonlines\",\n",
    "    Accept=\"application/jsonlines\",\n",
    "    Body=json.dumps(inputs).encode(\"utf-8\"),\n",
    ")\n",
    "print(\"response: {}\".format(response))\n",
    "\n",
    "# The response is requested as JSON Lines (see Accept above), i.e. one JSON document per\n",
    "# line. Split the raw text into lines and parse each line separately: calling\n",
    "# json.loads() on the whole payload fails for multi-line output, and its result is not\n",
    "# a string that could be split into lines.\n",
    "predicted_classes_str = response[\"Body\"].read().decode()\n",
    "\n",
    "# Skip blank lines (for example a trailing newline) so every entry is parseable.\n",
    "predicted_classes = [line for line in predicted_classes_str.splitlines() if line.strip()]\n",
    "print(\"predicted_classes: {}\".format(predicted_classes))\n",
    "\n",
    "# Pair each prediction line with the input it belongs to (same order as `inputs`).\n",
    "for predicted_class_json, input_data in zip(predicted_classes, inputs):\n",
    "    predicted_class = json.loads(predicted_class_json)[\"predicted_label\"]\n",
    "    print('Predicted star_rating: {} for review_body \"{}\"'.format(predicted_class, input_data[\"features\"][0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
